source("function_import.R")
unable to open connection to X11 display ''
processing file: 00-project-overview.Rmd
|
| | 0%
|
|...... | 9%
|
|............ | 18%
|
|.................. | 27%
|
|........................ | 36%
|
|.............................. | 45%
|
|................................... | 55%
|
|......................................... | 64%
|
|............................................... | 73%
|
|..................................................... | 82%
|
|........................................................... | 91%
|
|.................................................................| 100%
output file: 00-project-overview.R
unable to open connection to X11 display ''
processing file: 10-import-data.Rmd
|
| | 0%
|
|. | 2%
|
|.. | 3%
|
|... | 5%
|
|.... | 7%
|
|..... | 8%
|
|...... | 10%
|
|....... | 11%
|
|......... | 13%
|
|.......... | 15%
|
|........... | 16%
|
|............ | 18%
|
|............. | 20%
|
|.............. | 21%
|
|............... | 23%
|
|................ | 25%
|
|................. | 26%
|
|.................. | 28%
|
|................... | 30%
|
|.................... | 31%
|
|..................... | 33%
|
|...................... | 34%
|
|....................... | 36%
|
|......................... | 38%
|
|.......................... | 39%
|
|........................... | 41%
|
|............................ | 43%
|
|............................. | 44%
|
|.............................. | 46%
|
|............................... | 48%
|
|................................ | 49%
|
|................................. | 51%
|
|.................................. | 52%
|
|................................... | 54%
|
|.................................... | 56%
|
|..................................... | 57%
|
|...................................... | 59%
|
|....................................... | 61%
|
|........................................ | 62%
|
|.......................................... | 64%
|
|........................................... | 66%
|
|............................................ | 67%
|
|............................................. | 69%
|
|.............................................. | 70%
|
|............................................... | 72%
|
|................................................ | 74%
|
|................................................. | 75%
|
|.................................................. | 77%
|
|................................................... | 79%
|
|.................................................... | 80%
|
|..................................................... | 82%
|
|...................................................... | 84%
|
|....................................................... | 85%
|
|........................................................ | 87%
|
|.......................................................... | 89%
|
|........................................................... | 90%
|
|............................................................ | 92%
|
|............................................................. | 93%
|
|.............................................................. | 95%
|
|............................................................... | 97%
|
|................................................................ | 98%
|
|.................................................................| 100%
output file: 10-import-data.R
unable to open connection to X11 display ''
processing file: 20-explore.Rmd
|
| | 0%
|
|... | 5%
|
|...... | 10%
|
|......... | 14%
|
|............ | 19%
|
|............... | 24%
|
|................... | 29%
|
|...................... | 33%
|
|......................... | 38%
|
|............................ | 43%
|
|............................... | 48%
|
|.................................. | 52%
|
|..................................... | 57%
|
|........................................ | 62%
|
|........................................... | 67%
|
|.............................................. | 71%
|
|.................................................. | 76%
|
|..................................................... | 81%
|
|........................................................ | 86%
|
|........................................................... | 90%
|
|.............................................................. | 95%
|
|.................................................................| 100%
output file: 20-explore.R
unable to open connection to X11 display ''
processing file: 30-feature-engineering.Rmd
|
| | 0%
|
|................................ | 50%
|
|.................................................................| 100%
output file: 30-feature-engineering.R
unable to open connection to X11 display ''
processing file: 40-feature-selection.Rmd
|
| | 0%
|
|.... | 7%
|
|......... | 13%
|
|............. | 20%
|
|................. | 27%
|
|...................... | 33%
|
|.......................... | 40%
|
|.............................. | 47%
|
|................................... | 53%
|
|....................................... | 60%
|
|........................................... | 67%
|
|................................................ | 73%
|
|.................................................... | 80%
|
|........................................................ | 87%
|
|............................................................. | 93%
|
|.................................................................| 100%
output file: 40-feature-selection.R
unable to open connection to X11 display ''
processing file: 50-modeling.Rmd
|
| | 0%
|
|.. | 3%
|
|.... | 6%
|
|...... | 9%
|
|........ | 12%
|
|.......... | 15%
|
|............ | 18%
|
|.............. | 21%
|
|................ | 24%
|
|.................. | 27%
|
|.................... | 30%
|
|...................... | 33%
|
|........................ | 36%
|
|.......................... | 39%
|
|............................ | 42%
|
|.............................. | 45%
|
|................................ | 48%
|
|................................. | 52%
|
|................................... | 55%
|
|..................................... | 58%
|
|....................................... | 61%
|
|......................................... | 64%
|
|........................................... | 67%
|
|............................................. | 70%
|
|............................................... | 73%
|
|................................................. | 76%
|
|................................................... | 79%
|
|..................................................... | 82%
|
|....................................................... | 85%
|
|......................................................... | 88%
|
|........................................................... | 91%
|
|............................................................. | 94%
|
|............................................................... | 97%
|
|.................................................................| 100%
output file: 50-modeling.R
unable to open connection to X11 display ''
processing file: 60-results.Rmd
|
| | 0%
|
|..... | 8%
|
|.......... | 15%
|
|............... | 23%
|
|.................... | 31%
|
|......................... | 38%
|
|.............................. | 46%
|
|................................... | 54%
|
|........................................ | 62%
|
|............................................. | 69%
|
|.................................................. | 77%
|
|....................................................... | 85%
|
|............................................................ | 92%
|
|.................................................................| 100%
output file: 60-results.R
Loading required package: pacman
the condition has length > 1 and only the first element will be used (warning emitted three times)
Installing packages into ‘/home/yaqooba/R/x86_64-pc-linux-gnu-library/3.6’
(as ‘lib’ is unspecified)
package ‘c’ is not available (for R version 3.6.0)
trying URL 'https://cran.rstudio.com/src/contrib/plotly_4.9.2.1.tar.gz'
Content type 'application/x-gzip' length 3709741 bytes (3.5 MB)
downloaded 3.5 MB
trying URL 'https://cran.rstudio.com/src/contrib/dotwhisker_0.5.0.tar.gz'
Content type 'application/x-gzip' length 935078 bytes (913 KB)
downloaded 913 KB
trying URL 'https://cran.rstudio.com/src/contrib/broom_0.7.0.tar.gz'
Content type 'application/x-gzip' length 604195 bytes (590 KB)
downloaded 590 KB
* installing *source* package ‘plotly’ ...
** package ‘plotly’ successfully unpacked and MD5 sums checked
** using staged installation
** R
** data
*** moving datasets to lazyload DB
** demo
** inst
** byte-compile and prepare package for lazy loading
** help
*** installing help indices
*** copying figures
** building package indices
** testing if installed package can be loaded from temporary location
** testing if installed package can be loaded from final location
** testing if installed package keeps a record of temporary installation path
* DONE (plotly)
* installing *source* package ‘broom’ ...
** package ‘broom’ successfully unpacked and MD5 sums checked
** using staged installation
** R
** inst
** byte-compile and prepare package for lazy loading
** help
*** installing help indices
*** copying figures
** building package indices
** installing vignettes
** testing if installed package can be loaded from temporary location
** testing if installed package can be loaded from final location
** testing if installed package keeps a record of temporary installation path
* DONE (broom)
* installing *source* package ‘dotwhisker’ ...
** package ‘dotwhisker’ successfully unpacked and MD5 sums checked
** using staged installation
** R
** inst
** byte-compile and prepare package for lazy loading
** help
*** installing help indices
** building package indices
** installing vignettes
** testing if installed package can be loaded from temporary location
** testing if installed package can be loaded from final location
** testing if installed package keeps a record of temporary installation path
* DONE (dotwhisker)
The downloaded source packages are in
‘/tmp/Rtmp7eQEhF/downloaded_packages’
'BiocManager' not available. Could not check Bioconductor.
Please use `install.packages('BiocManager')` and then retry.
Installing package into ‘/home/yaqooba/R/x86_64-pc-linux-gnu-library/3.6’
(as ‘lib’ is unspecified)
trying URL 'https://cran.rstudio.com/src/contrib/plotly_4.9.2.1.tar.gz'
Content type 'application/x-gzip' length 3709741 bytes (3.5 MB)
downloaded 3.5 MB
* installing *source* package ‘plotly’ ...
** package ‘plotly’ successfully unpacked and MD5 sums checked
** using staged installation
** R
** data
*** moving datasets to lazyload DB
** demo
** inst
** byte-compile and prepare package for lazy loading
** help
*** installing help indices
*** copying figures
** building package indices
** testing if installed package can be loaded from temporary location
** testing if installed package can be loaded from final location
** testing if installed package keeps a record of temporary installation path
* DONE (plotly)
The downloaded source packages are in
‘/tmp/Rtmp7eQEhF/downloaded_packages’
plotly installed
Purpose. In this work, we will explore the relation between identified measures of despair of interest (e.g., personality measures of self-consciousness, individual and composite item scores from the CES-D assessment) and descriptors of diseases of despair. We will achieve this goal by modeling the outcomes as functions of the included predictors and by robustly assessing, via bootstrapping, how important each included feature is for predicting the outcomes. We will use two well-known machine learning models, random forests and LASSO, both of which are frequently used to measure the relative importance of the predictors included in a model. Lastly, we’ll generate trained and tuned models using this reduced feature set, which can be used by others who wish to predict the identified outcomes.
Subject inclusion. For this investigation, we will omit the entirety of Wave 2. This is commonly done in analyses of AddHealth data due to the design of the original study. Otherwise, our dataset will include only subjects who have predictor and outcome data in all of the waves.
Outcome variables. In this experiment, we assess prescription drug use at Wave 5.
Predictor variables. The predictors for these models are hand-picked based on previous work, relevance, and subject-matter expertise. The set of predictors and the set of outcomes are disjoint. Predictors from Waves 1-4 (excluding Wave 2; see above) are included and will be detailed in the following analysis.
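To make the bootstrapped importance assessment concrete, here is a minimal, self-contained sketch of the idea. It uses the built-in mtcars data and absolute standardized coefficients as a stand-in for the project’s importance measures; every name in it is illustrative and not part of the actual pipeline.
# Illustrative sketch only: rank predictors by their median importance
# across bootstrap resamples (toy data, not the AddHealth pipeline)
set.seed(1)
toy <- data.frame(scale(mtcars))  # standardize so coefficients are comparable
n_boot <- 50                      # illustrative bootstrap count
ranks <- replicate(n_boot, {
  idx <- sample(nrow(toy), replace = TRUE)  # resample rows with replacement
  fit <- lm(mpg ~ ., data = toy[idx, ])     # refit the model on the resample
  rank(-abs(coef(fit)[-1]))                 # rank predictors by |coefficient|
})
sort(apply(ranks, 1, median))     # median rank per predictor, best first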
seed <- 3895
set.seed(seed)
The predictors we will use are given by the variable predictor_list loaded from the 10-import-data.Rmd file. This initial set of predictors is based on the list of variables that describe:
1. anxiety
2. depression
3. optimism
full_dataset %>%
group_by(p_drug) %>%
summarise(total = n(), type = class(p_drug))
`summarise()` ungrouping output (override with `.groups` argument)
## get the subject ids (aids) present in all of the joined waves
inner_aids <- get_inner(list(wave_data[[1]], wave_data[[3]], wave_data[[4]], wave_data[[5]]))
## na_levels: lookup table used to recode all skip levels in variables to NA
na_levels <- read_csv("na_levels.csv")
## select the desired features and subjects
pr_drug_ds <- full_dataset %>%
add_demographics() %>%
add_bio_despair() %>%
dplyr::select(aid, predictor_list, demographic_age_list, demographic_list, outcome) %>%
recode_missing_levels(na_levels) %>%
filter(aid %in% inner_aids) %>%
remove_subjects_not_in_wave1() %>%
fix_outcome_variable("p_drug") %>%
mutate(p_drug = as.factor(p_drug)) %>%
dplyr::select(-c(h5waist, h5bmi, h5dbp, h5bpjcls, h5bpcls4, h5sbp))
[1] "Recoding Missing Factor Variables"
[1] "Factor variables being recoded : 65"
[1] "Recoding Missing Numeric Variables"
[1] "Numeric variables being recoded : 10"
# full_dataset %>%
# add_pres_drug %>%
# add_demographics() %>%
# dplyr::select(aid, predictor_list, demographic_age_list, demographic_list, outcome) %>%
# filter(aid %in% inner_aids) %>%
# remove_subjects_not_in_wave1(filebase='Z:') %>%
# mutate_at(vars(-starts_with("age_")), as.factor) %>%
# mutate_at(vars(c(-starts_with("age_"),outcome)), fct_explicit_na) %>%
# drop_na(outcome)
# Report about the characteristics of the subjects left out of the join
count_not_joined(wave_data = wave_data, number_waves_joined = 5)
# Validate the generated dataset using asserts
pr_drug_ds %>%
group_by(p_drug) %>%
summarise(total = n(), type = class(p_drug))
`summarise()` ungrouping output (override with `.groups` argument)
Here, we comment about the general characteristics of the data based on the provided visualizations. We comment on missingness of data, any strange or unusual behavior (e.g., strong imbalances), and any correlation that sticks out.
# Visualize distributions of variables of interest
pr_drug_ds %>%
dplyr::select(-aid) %>%
graph_bar_discrete(df = .,
plot_title = "Distributions of Discrete Variables",
max_categories = 50,
num_rows = 3,
num_cols = 3,
x_axis_size = 12,
y_axis_size = 12,
title_size = 15)
# Visualize missingness
graph_missing(pr_drug_ds,
only_missing = TRUE,
title = "Percent Missing",
box_line_size = .5,
label_size = .5,
x_axis_size = 12,
y_axis_size = 12,
title_size = 15)
# Visualize correlation among first 20 predictors
pr_drug_ds %>%
dplyr::select(1:20) %>%
pairwise_cramers_v() %>%
plot_cramer_v(x_axis_angle = 90,
plot_title = "Association among Categorical Variables",
interactive = TRUE)
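pairwise_cramers_v() is a project helper whose definition is not shown; for reference, a minimal sketch of Cramér’s V for a single pair of categorical variables (the call on the built-in mtcars data is purely illustrative):
# Cramér's V for one pair of categorical variables:
# V = sqrt(chi-squared / (n * (min(rows, cols) - 1)))
cramers_v <- function(a, b) {
  tab <- table(a, b)
  chi2 <- suppressWarnings(chisq.test(tab, correct = FALSE)$statistic)
  as.numeric(sqrt(chi2 / (sum(tab) * (min(dim(tab)) - 1))))
}
cramers_v(mtcars$cyl, mtcars$gear)  # illustrative example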
full_dataset %>%
filter(aid %in% inner_aids) %>%
add_wave_4_lipids() %>%
group_by(hdl) %>%
summarise(total = n(), type = class(hdl))
`summarise()` ungrouping output (override with `.groups` argument)
pr_drug_ds %>%
group_by(hdl) %>%
summarise(total = n(), type = class(hdl))
`summarise()` ungrouping output (override with `.groups` argument)
In this section, we split the data into training, validation, and test sets so that we can assess how well our models generalize to unseen data.
## split the data into the desired proportions
data_splits <- pr_drug_ds %>%
split_data(strat_var = outcome, ratios=c(0.7, 0.2, 0.1))
# assemble list
training_df <- data_splits$train
validation_df <- data_splits$valid
testing_df <- data_splits$test
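split_data() is a project helper whose definition is not shown here. As a rough sketch of what a stratified 70/20/10 split looks like in base R (all names below are hypothetical stand-ins, for intuition only):
# Hypothetical stand-in for split_data(): a stratified 70/20/10 split
assign_stratum <- function(n, ratios) {
  counts <- diff(c(0, round(cumsum(ratios) * n)))  # sizes summing exactly to n
  sample(rep(seq_along(ratios), times = counts))   # shuffled partition labels
}
stratified_split <- function(df, strat_var, ratios = c(0.7, 0.2, 0.1)) {
  # label rows within each outcome stratum so class balance is preserved
  labels <- ave(seq_len(nrow(df)), df[[strat_var]],
                FUN = function(idx) assign_stratum(length(idx), ratios))
  split(df, c("train", "valid", "test")[labels])
}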
The RF models are chosen based on a grid search over the following parameters:
The following table displays the mean performance metrics for the bootstrapped models on the validation set, with NA values removed.
mean_bs_rf_perf <- get_metric_set_from_perfs(pr_drug_rf$perfs) %>%
dplyr::select(accuracy, mpce, sens, spec, ppv, npv, roc_auc, pr_auc,
tns, tps, fns, fps, no_n, no_p, err_rate, bal_accuracy, everything()) %>%
summarise_if(is.numeric, mean, na.rm=TRUE) %>%
mutate(model = 'bs_rf') %>%
dplyr::select(model, everything())
mean_bs_rf_perf
As shown, the bootstrapped models tend to have high specificity but low sensitivity, indicating that identifying subjects with the outcome (prescription drug use) is the harder part of the problem.
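As a quick reminder of what these two metrics trade off, a sketch with made-up counts (not values from this study):
# Illustrative confusion-matrix arithmetic (made-up counts, not study results)
tp <- 30;  fn <- 70   # true positives / false negatives
tn <- 880; fp <- 20   # true negatives / false positives
c(sensitivity = tp / (tp + fn),  # 0.30: most true positives are missed
  specificity = tn / (tn + fp))  # ~0.98: negatives are almost always correct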
boot_rf_mdi <- pr_drug_rf$mdi %>%
get_median_placement(use_base_var = TRUE) %>%
add_attribute_names('predictor', full_dataset) %>%
dplyr::select(predictor, att_name, overall_rank)
`summarise()` ungrouping output (override with `.groups` argument)
head(boot_rf_mdi, 20)
The table above shows the MDI variable importance ranks aggregated (by median placement) across the bootstrapped models.
# TODO: fix overlapping axis labels that obscure the plot
plot_placement_boxplot(pr_drug_rf$mdi)
Now, let’s look at the permutation importance:
boot_rf_perm_plt <- pr_drug_rf$models %>%
get_aggregated_permute_imp(training_df, outcome=outcome)
met <- 'pr_auc'
boot_rf_perm <- boot_rf_perm_plt %>%
get_permute_placement(metric_oi=met) %>%
add_attribute_names('predictor', full_dataset) %>%
dplyr::select(predictor, everything())
head(boot_rf_perm, 20)
In this step, we assess the differences between the two types of importance rankings (MDI vs. permutation).
cbind(boot_rf_mdi[1:20,], dplyr::select(boot_rf_perm[1:20,], -all_of(met)))
As shown, MDI importance is biased by the number of distinct values associated with a predictor. Because the wave age variables have many more values than the other factors, their MDI importance is artificially inflated. The permutation importance is the more intuitive measure and does not share this bias.
plot_permute_var_imp(boot_rf_perm, metric = pr_auc)
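get_aggregated_permute_imp() wraps the underlying idea, which can be sketched in a few lines of base R (toy model and data, for intuition only):
# Toy permutation importance: shuffle one predictor at a time and measure
# how much the model's error grows (built-in mtcars data, linear model)
set.seed(1)
fit <- lm(mpg ~ wt + hp + disp, data = mtcars)
rmse <- function(d) sqrt(mean((predict(fit, d) - d$mpg)^2))
base_rmse <- rmse(mtcars)
sort(sapply(c("wt", "hp", "disp"), function(v) {
  d <- mtcars
  d[[v]] <- sample(d[[v]])  # permute one column, keeping everything else fixed
  rmse(d) - base_rmse       # increase in error = importance of this predictor
}), decreasing = TRUE)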
In this step, we model the relation between the outcome and the predictors using a regression with L1 (LASSO) regularization, which drives the coefficients of unimportant and redundant features toward zero.
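model_feature_selection() is a project wrapper whose internals are not shown; with the glmnet package, the core idea looks roughly like the sketch below, where alpha = 1 selects the pure L1 (lasso) penalty and the simulated data are purely illustrative:
# Minimal glmnet sketch on simulated data: the L1 penalty zeroes out the
# coefficients of uninformative predictors (alpha = 0 would be ridge/L2)
library(glmnet)
set.seed(1)
x <- matrix(rnorm(500 * 10), 500, 10)             # 10 candidate predictors
y <- rbinom(500, 1, plogis(x[, 1] - 2 * x[, 2]))  # only two truly matter
cv_fit <- cv.glmnet(x, y, family = "binomial", alpha = 1)
coef(cv_fit, s = "lambda.min")  # coefficients shrunk exactly to zero are dropped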
# Function parameters
lasso_params <- list(alpha = c(1))
# Call modeling function using function parameters and show visualization of results. Recommend the number of features that should be used. Report performance metric stats.
pr_drug_lasso <- model_feature_selection( "Lasso",
training_frame = training_df,
validation_frame = validation_df,
hyper_params = lasso_params,
outcome = outcome,
n = n_boot)
mean_bs_lasso_perf <- get_metric_set_from_perfs(pr_drug_lasso$perfs) %>%
dplyr::select(accuracy, mpce, sens, spec, ppv, npv, roc_auc, pr_auc,
tns, tps, fns, fps, no_n, no_p, err_rate, bal_accuracy, everything()) %>%
summarise_if(is.numeric, mean, na.rm=TRUE) %>%
mutate(model='bs_lasso') %>%
dplyr::select(model, everything())
mean_bs_lasso_perf
boot_lasso_mdi <- pr_drug_lasso$mdi %>%
get_median_placement(use_base_var = TRUE) %>%
add_attribute_names('predictor', full_dataset) %>%
dplyr::select(predictor, att_name, overall_rank)
`summarise()` ungrouping output (override with `.groups` argument)
head(boot_lasso_mdi, 20)
plot_placement_boxplot(pr_drug_lasso$mdi)
boot_lasso_perm_plt <- pr_drug_lasso$models %>%
get_aggregated_permute_imp(training_df, outcome=outcome)
boot_lasso_perm <- boot_lasso_perm_plt %>%
get_permute_placement(metric_oi=met) %>% #set in random forest section
add_attribute_names('predictor', full_dataset) %>%
dplyr::select(predictor, everything())
head(boot_lasso_perm, 20)
plot_permute_var_imp(boot_lasso_perm, metric = pr_auc)
Now, we compare the feature importances generated by the two different approaches. The traditional method of evaluating feature importance for regression methods is through analysis of the coefficients.
cbind(boot_lasso_mdi[1:20,], dplyr::select(boot_lasso_perm[1:20,], -met))
Note: Using an external vector in selections is ambiguous.
ℹ Use `all_of(met)` instead of `met` to silence this message.
ℹ See <https://tidyselect.r-lib.org/reference/faq-external-vector.html>.
This message is displayed once per session.
The following table compares the mean performance of bootstrapped random forests to the mean performance of bootstrapped LASSO methods.
bs_comp_perfs <- rbind(mean_bs_rf_perf, mean_bs_lasso_perf)
bs_comp_perfs
Here, we aggregate the bootstrapped importance results and compare the two model types to each other.
joined_results <- boot_rf_perm %>%
dplyr::select(-met) %>%
full_join(dplyr::select(boot_lasso_perm, -met), by=c("predictor", "att_name"), suffix=c('.rf', '.lasso')) %>%
mutate(mean_rank = (overall_rank.rf+overall_rank.lasso)/2) %>%
arrange(mean_rank)
head(joined_results, 20)
The following visualization provides intuition about how the rankings differ between model types. Predictors are ordered by overall mean importance, and for each variable the difference in rank between models is shown.
# Comparison of top_n features
joined_results %>%
compare_feature_select(interactive = TRUE,
top_n = 100,
opacity = 0.50,
plot_title = "Permutation Importance of Predictors by Model")
Note: Using an external vector in selections is ambiguous.
ℹ Use `all_of(sel_cols)` instead of `sel_cols` to silence this message.
ℹ See <https://tidyselect.r-lib.org/reference/faq-external-vector.html>.
This message is displayed once per session.
`group_by_()` is deprecated as of dplyr 0.7.0.
Please use `group_by()` instead.
See vignette('programming') for more help
This warning is displayed once every 8 hours.
Call `lifecycle::last_warnings()` to see where this warning was generated.
In this step, we build the final random forest model. We use a slightly wider grid of hyperparameter values in order to arrive at the best model, keeping in mind how many combinations must be run to evaluate the grid (a quick count follows the parameter definitions below).
# Spans of hyperparameters for the random forest
rf_params <- list(max_depth = 50,
ntrees = 150,
mtries = seq(-1, 30, by=5),
min_rows = seq(5, 60, by=5),
balance_classes = c(TRUE, FALSE),
stopping_metric = 'AUCPR',
categorical_encoding = 'one_hot_explicit')
# rf_params <- list(max_depth = seq(20, 50, 20),
# balance_classes = TRUE,
# categorical_encoding= 'one_hot_explicit')
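For the combination count mentioned above, a quick check of how many models the active grid implies:
# Size of the hyperparameter grid defined above:
# 7 values of mtries x 12 values of min_rows x 2 values of balance_classes
length(seq(-1, 30, by = 5)) * length(seq(5, 60, by = 5)) * 2  # = 168 models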
# Function parameters
final_model_rf <- rf_model(outcome,
training_frame = training_df,
validation_frame = validation_df,
nfolds = 5,
hyper_params = rf_params, model_seed=seed)
`aid` is in your training data frame. Dropping for modeling purposes.
`aid` is in your validation data frame. Dropping for modeling purposes.
`summarise()` ungrouping output (override with `.groups` argument)
The final random forest performance metrics are shown below:
# show model final performance
print(final_model_rf[[2]])
final_rf_perm_plt <- c(final_model_rf[[1]]) %>%
get_aggregated_permute_imp(training_df, outcome=outcome)
final_rf_perm <- final_rf_perm_plt %>%
get_permute_placement(metric_oi=met) %>%
add_attribute_names('predictor', full_dataset) %>%
dplyr::select(predictor, everything())
head(final_rf_perm, 20)
plot_permute_var_imp(final_rf_perm, metric = pr_auc)
This section investigates the differences in the bootstrap results vs the features generated from the random forest final model. The following table shows the overall differences in rank.
rf_joined_results <- final_rf_perm %>%
dplyr::select(-met) %>%
full_join(dplyr::select(boot_rf_perm, -met), by=c("predictor", "att_name"), suffix=c('.final', '.bootstrap')) %>%
mutate(mean_rank = (overall_rank.final + overall_rank.bootstrap)/2) %>%
arrange(mean_rank)
head(rf_joined_results, 20)
The following plot visualizes the differences between the final-model rankings and the bootstrap rankings.
# Comparison of top_n features
rf_joined_results %>%
compare_feature_select(sel_cols = c("overall_rank.final", "overall_rank.bootstrap"),
interactive = TRUE,
top_n = 100,
opacity = 0.50,
plot_title = "Permutation Importance of Predictors: Final vs. Bootstrap")
Now, we create the final LASSO model. This step differs from the bootstrap runs only in the data on which the model is built.
# Function parameters
lasso_params <- list(alpha = c(1))
final_model_lasso <- lasso_model(training_frame = training_df,
validation_frame = validation_df,
outcome = outcome,
nfolds = 5,
hyper_params = lasso_params)
The final LASSO performance metrics are shown below:
# show model final performance
print(final_model_lasso[[2]])
final_lasso_perm_plt <- c(final_model_lasso[[1]]) %>%
get_aggregated_permute_imp(training_df, outcome=outcome)
final_lasso_perm <- final_lasso_perm_plt %>%
get_permute_placement(metric_oi=met) %>%
add_attribute_names('predictor', full_dataset) %>%
dplyr::select(predictor, everything())
head(final_lasso_perm, 20)
plot_permute_var_imp(final_lasso_perm, metric = pr_auc)
This section investigates the differences in the bootstrap results vs the features generated from the LASSO final model. The following table shows the overall differences in rank.
lasso_joined_results <- final_lasso_perm %>%
dplyr::select(-met) %>%
full_join(dplyr::select(boot_lasso_perm, -met), by=c("predictor", "att_name"), suffix=c('.final', '.bootstrap')) %>%
mutate(mean_rank = (overall_rank.final + overall_rank.bootstrap)/2) %>%
arrange(mean_rank)
head(lasso_joined_results, 20)
The following plot visualizes the differences between the final-model rankings and the bootstrap rankings.
# Comparison of top_n features
lasso_joined_results %>%
compare_feature_select(sel_cols = c("overall_rank.final", "overall_rank.bootstrap"),
interactive = TRUE,
top_n = 100,
opacity = 0.50,
plot_title = "Permutation Importance of Predictors: Final vs. Bootstrap")
Here, we compare the features generated by the permutation importance between the two final models.
rf_lasso_final_joined_results <- final_rf_perm %>%
dplyr::select(-met) %>%
full_join(dplyr::select(final_lasso_perm, -met), by=c("predictor", "att_name"), suffix=c('.rf', '.lasso')) %>%
mutate(mean_rank = (overall_rank.rf+overall_rank.lasso)/2) %>%
arrange(mean_rank)
head(rf_lasso_final_joined_results, 20)
The following visualization provides intuition about how the rankings differ between the two final models. Predictors are ordered by overall mean importance, and for each variable the difference in rank between models is shown.
# Comparison of top_n features
rf_lasso_final_joined_results %>%
compare_feature_select(sel_cols = c("overall_rank.rf", "overall_rank.lasso"),
interactive = TRUE,
top_n = 100,
opacity = 0.50,
plot_title = "Permutation Importance of Predictors: Random Forest vs Lasso")
With the final models generated, we’re now able to compare their performance metrics.
# Comparison of performance metrics
valid_perf <- get_metric_set_from_perfs(perf_list = list(final_model_rf[[2]], final_model_lasso[[2]])) %>%
mutate(model = c('rf', 'lasso'))
testing_perf <- get_metric_set_from_models(testing_df, list(final_model_rf[[1]], final_model_lasso[[1]]), out=outcome) %>%
mutate(model = c('rf', 'lasso'))
`summarise()` ungrouping output (override with `.groups` argument)
`summarise()` ungrouping output (override with `.groups` argument)
Validation and selection. The following table shows the comparison between models on the validation set. We select our final model as the one that performs best according to the chosen metric.
print(valid_perf)
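With `met` set to 'pr_auc' earlier, selecting from the validation table can be as simple as the following one-liner (a sketch; the project may apply additional criteria):
# Pick the model with the best validation PR-AUC (the metric of interest)
valid_perf$model[which.max(valid_perf$pr_auc)]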
Testing performance. The following shows the performance of both models on the test set. Note that although we do not use the test set to select the final model, we can still see how our selected method would have performed on fully held-out data.
print(testing_perf)
The following plots show a comparison between the performance of the models on the validation and test sets. Again, we don’t choose the model based on the test set, but curiosity dictates that we view this performance.
# Show plots side by side
metrics_of_interest <- c('model', 'accuracy', 'bal_accuracy', 'mpce', 'sens', 'spec', 'ppv', 'npv', 'pr_auc', 'roc_auc')
valid_plt <- plot_metric_set(dplyr::select(valid_perf, all_of(metrics_of_interest)), plot_title = "Model comparison for validation set")
test_plt <- plot_metric_set(dplyr::select(testing_perf, all_of(metrics_of_interest)), plot_title = "Model comparison for testing set")
gridExtra::grid.arrange(gridExtra::arrangeGrob(valid_plt, test_plt, ncol=2, nrow=1))
Here, the subject matter experts will comment on the differences in the features obtained across the studied outcome variables and discuss any discrepancies and/or cohesion.
# Show differences in features obtained